using PyPlot, DifferentialEquations, LaTeXStrings, Random, DelimitedFiles, NPZ, LinearAlgebra
include("Lib2.jl")
# Print a line containing a single space before the script's own output.
println(" ")

"""
    makeIndexTableV14(Ne)

Build the lookup table from (spin component, ensemble) to a position in the flat
state vector. Row 1 holds the x indices, row 2 the y indices and row 3 the z
indices. All x entries come first, then all y, then all z, so
`tbl[c, i] == (c - 1) * Ne + i`. Returns a `3 × Ne` matrix of `Int`.
"""
function makeIndexTableV14(Ne)
    # Row c is a block of Ne consecutive indices starting right after row c-1.
    return [(c - 1) * Ne + i for c in 1:3, i in 1:Ne]
end

"""
    main()

Simulate one mean-field trajectory of `Ne` spin ensembles coupled to a lossy cavity
while the pump is ramped up. Each time step `dt` combines:

- a deterministic ODE step (`Tsit5`), and
- a stochastic kick driven by complex noise `ξ`.

Returns `(integrator, expVals, Ne, S, g, ωz, tc, J, E_SD, E0, ωc, κ, α)`, where
`expVals = (t_, x_, y_, z_, ξ_, E_)` holds:

- `t_`: the saved times;
- `x_`, `y_`, `z_`: the spin components, one column per ensemble;
- `ξ_`: the noise samples drawn at the saved steps;
- `E_`: the energy divided by `M`.

`tc` is the first grid time at which `g(t)^2/gc^2 >= ωz(t)/ωz0` (0.0 if never
reached). The returned `J` includes any diagonal shift applied to make it PSD.

Relies on `smoothstep`, `confocalJrealHalfPlane`, `localMins` and `SD!` from
`Lib2.jl`.
"""
function main()

    # Main Parameters
    Ne = 15                          # Number of spin ensembles
    M  = 10000; S = M/2.0            # Number of atoms per ensemble; S = M/2
    ωc = 2*π * 80                    # Mode detuning MHz
    κ  = 2*π * 0.13                  # Cavity loss MHz
    ωz0 = 2*π * 0.015                # Atomic frequency MHz

    # Ramp functions
    rampTime = 2000.0                # us
    rampFactor = 5                   # Final power g^2/gc^2 (assuming f(t) -> 1)
    offset = 0.0                     # Final ωz/ωz0 (assuming smoothstep -> 1)
    f(t) = smoothstep((t-100)/600)   # Ramp profile; this function needs to go to 1.
    gNorm(t) = sqrt(rampFactor*f(t)) # Dimensionless pump strength g/gc
    ωz(t) = ωz0 * max( 0, 1 - (1-offset)*smoothstep((t-100)/600)  )

    # Time step for noise (also the length of each deterministic integrator step)
    dt = 1

    # Measurement related: number of saved time points
    numTvalsObs = 1000

    # Set the seed for random number generation and print it so the run can be
    # reproduced. `abs(rand(Int16))` stays negative for typemin(Int16) (abs
    # overflows), which `Random.seed!` may reject, so draw a non-negative seed directly.
    randSeed = rand(0:typemax(Int16))
    Random.seed!(randSeed)
    println("Seed: "*string(randSeed))

    ############################################
    # Construct a J matrix
    ############################################

    # Random SK J matrix
    # J = randn(Ne,Ne)
    # J = (J+J')./2

    # Generate a random confocal J matrix
    pos, J = confocalJrealHalfPlane(Ne,s=0.0,ϕ=0.02,w=2)

    # Load a J matrix
    # Jidx = 1
    # J = npzread("JsN15_w2.npz")[:,:,Jidx]

    # Reference local minima of the Ising problem on the off-diagonal part of J
    GS,E0 = localMins(J-diagm(diag(J)))

    ############################################
    # Analyze J matrix and make pump strength function
    ############################################

    # Compute eigendecomposition and make sure J is PSD. If it is not, shift its
    # diagonal (in place) by the most negative eigenvalue so the smallest one becomes ~0.
    Jeig = eigen(J)
    Jevals, Jevecs = Jeig.values, Jeig.vectors
    if minimum(Jevals)<0
        println("Warning: J is not PSD, with eigenvalue "*string(minimum(Jevals)))
        minVal = minimum(Jevals)
        for i=1:Ne
            J[i,i] = J[i,i]-minVal
        end
        Jeig = eigen(J)
        Jevals, Jevecs = Jeig.values, Jeig.vectors
        Jevals = abs.(Jevals)   # drop round-off negatives near zero before the sqrt below
    end

    # Off-diagonal part of J, used for the Ising energies
    Jnon = copy(J)
    for i=1:Ne
        Jnon[i,i] = 0
    end

    # Compute mode couplings α[i,m] = sqrt(λ_m) v_m[i] (so α*α' ≈ J for symmetric J)
    α = zeros(Ne,Ne)
    for i=1:Ne
        for m=1:Ne
            α[i,m] = sqrt(Jevals[m])*Jevecs[i,m]
        end
    end

    # Compute critical pump strength
    λmax = maximum(eigen(J).values) # NOTE: supposed to be with Jκ, but that doesn't seem to match?
    gc = sqrt((ωc^2+κ^2)/ωc * ωz0 /(M*λmax))
    println("Critical g: "*string(gc))

    g(t) = gc*gNorm(t)

    # Find tc for the semiclassical dynamics: first time on a 10000-point grid over
    # (0, rampTime] with g(t)^2/gc^2 >= ωz(t)/ωz0; 0.0 if the condition is never met.
    tcIdx = findfirst(i -> g(i/10000*rampTime)^2 / gc^2 >= ωz(i/10000*rampTime)/ωz0, 1:10000)
    tc = isnothing(tcIdx) ? 0.0 : tcIdx/10000*rampTime

    tspan = (0.0, rampTime)

    # Time-dependent coefficient buffers, refreshed in place by computeCoefs
    Jmod = zeros(Ne,Ne)
    K    = complex.(zeros(Ne,Ne))
    Kr   = zeros(Ne,Ne)
    Ki   = zeros(Ne,Ne)
    γx   = complex.(zeros(Ne,Ne))
    γy   = complex.(zeros(Ne,Ne))
    Γxx  = complex.(im*γx*γx')
    Γyy  = complex.(im*γy*γy')
    Γxy  = complex.(im*conj.(γx)*conj.(γy)')
    Γyx  = complex.(im*conj.(γy)*conj.(γx)')
    ImΓxx = imag.(Γxx)
    ImΓyy = imag.(Γyy)
    ImΓxy = imag.(Γxy)
    ImΓyx = imag.(Γyx)
    ReΓxx = real.(Γxx)
    ReΓyy = real.(Γyy)
    ReΓxy = real.(Γxy)
    ReΓyx = real.(Γyx)

    # Overwrite the coefficient buffers above with their values at time t, using
    # pump g(t) and atomic frequency ωz(t).
    function computeCoefs(t)

        ωzt = ωz(t)
        gt  = g(t)*2  # NOTE(review): factor 2 on the pump here; confirm convention
        gt2 = gt^2

        # Compute modified J and K for Morigi theory.
        Jmod .= gt2 * real.( (α*α')*( 1/(ωc+ωzt-im*κ) + 1/(ωc-ωzt-im*κ) )/2 )
        K .= gt2 * im*(α*α')*( 1/(ωc+ωzt-im*κ) - 1/(ωc-ωzt-im*κ) )/2
        Kr .= real.(K)
        Ki .= imag.(K)

        γx .=    gt * sqrt(κ/2)*α*( 1/(ωc+ωzt-im*κ) + 1/(ωc-ωzt-im*κ) )
        γy .= gt * im*sqrt(κ/2)*α*( 1/(ωc+ωzt-im*κ) - 1/(ωc-ωzt-im*κ) )

        Γxx .= im*γx*γx'
        Γyy .= im*γy*γy'
        Γxy .= im*conj.(γx)*conj.(γy)'
        Γyx .= im*conj.(γy)*conj.(γx)'
        ImΓxx .= imag.(Γxx)
        ImΓyy .= imag.(Γyy)
        ImΓxy .= imag.(Γxy)
        ImΓyx .= imag.(Γyx)
        ReΓxx .= real.(Γxx)
        ReΓyy .= real.(Γyy)
        ReΓxy .= real.(Γxy)
        ReΓyx .= real.(Γyx)
        return nothing
    end

    ############################################
    # Find ground state ground energies
    ############################################

    # Binarize the eigenvector of the largest eigenvalue (last column: eigen sorts the
    # eigenvalues of a symmetric J in ascending order); it minimizes -x'Jx on the sphere.
    binEvec = sign.(Jevecs[:,length(Jevals)])

    # Do steepest descent on binarized eigenvector
    binEvecSD = copy(binEvec)
    SD!(binEvecSD,Jnon)
    E_SD = -binEvecSD' * Jnon * binEvecSD

    ############################################
    # Initial conditions and other variables
    ############################################

    # indices used
    x, y, z = 1:3
    tb = makeIndexTableV14(Ne)
    totalSize = maximum(tb)

    u0 = zeros(totalSize)

    # Starting in normal phase: every spin fully down along z
    for i=1:Ne
        u0[tb[z,i]] = -S
    end

    # Starting with some noise
    # for i=1:Ne
    #     θ0 = π*randn()*0.02 *0
    #     ϕ0 = 2π*rand()

    #     u0[tb[x,i]] =  S*sin(θ0)*cos(ϕ0)
    #     u0[tb[y,i]] =  S*sin(θ0)*sin(ϕ0)
    #     u0[tb[z,i]] = -S*cos(θ0)
    # end

    ##############################
    # Derivative functions
    ##############################

    # Deterministic mean-field equations of motion written into du; the
    # coefficients are refreshed for time t first. `p` is unused.
    function QSD_MF_deriv!(du,u,p,t)

        ωzt = ωz(t)

        # Update coefficients
        computeCoefs(t)

        @inbounds for k=1:Ne
            du[tb[x,k]] = -(ωzt + 0.5*Ki[k,k] - 0.5*ImΓxy[k,k])*u[tb[y,k]] - 0.5*ImΓyy[k,k]*u[tb[x,k]]

            du[tb[y,k]] =  (ωzt + 0.5*Ki[k,k] + 0.5*ImΓyx[k,k])*u[tb[x,k]] - 0.5*ImΓxx[k,k]*u[tb[y,k]]

            du[tb[z,k]] = -0.5*( ImΓxx[k,k] - ImΓyy[k,k] )*u[tb[z,k]]

            @inbounds for i=1:Ne
                du[tb[x,k]] += - u[tb[z,k]]*( (Kr[i,k]-ReΓxy[i,k])*u[tb[x,i]] - ReΓyy[i,k]*u[tb[y,i]] )

                du[tb[y,k]] += u[tb[z,k]]*( (Kr[i,k]-ReΓyx[i,k])*u[tb[y,i]] + (2*Jmod[i,k]-ReΓxx[i,k])*u[tb[x,i]] )

                du[tb[z,k]] += (Kr[i,k]-ReΓxy[i,k])*u[tb[x,i]]*u[tb[x,k]] + (ReΓxx[i,k]-2*Jmod[i,k])*u[tb[x,i]]*u[tb[y,k]]
                du[tb[z,k]] += (ReΓyx[i,k]-Kr[i,k])*u[tb[y,i]]*u[tb[y,k]] - ReΓyy[i,k]*u[tb[y,i]]*u[tb[x,k]]
            end
        end
        return nothing
    end

    ##############################
    # Stochastic parts
    ##############################

    ξ = complex.(zeros(Ne))

    # Apply one stochastic kick to integrator.u in place, using noise increments
    # ΔW = sqrtdt*ξ coupled through γx and γy evaluated at the current time.
    function StochasticStep!(integrator,ξ,sqrtdt)

        ΔW = sqrtdt*ξ

        # Update coefficients
        computeCoefs(integrator.t)

        ΔWx = imag.(γx*ΔW)
        ΔWy = imag.(γy*ΔW)

        # NOTE(review): the z update reads x and y after they were already kicked in
        # this same iteration; confirm this sequential ordering is intended.
        @inbounds for k=1:Ne
            integrator.u[tb[x,k]] +=  ΔWy[k]*integrator.u[tb[z,k]]
            integrator.u[tb[y,k]] += -ΔWx[k]*integrator.u[tb[z,k]]
            integrator.u[tb[z,k]] +=  ΔWx[k]*integrator.u[tb[y,k]] - ΔWy[k]*integrator.u[tb[x,k]]
        end

        # u was changed outside the solver. Tell the integrator so that FSAL methods
        # such as Tsit5 re-evaluate the derivative at the kicked state instead of
        # reusing the one cached at the end of the previous deterministic step.
        u_modified!(integrator, true)
        return nothing
    end

    ############################################
    # Compute the trajectory
    ############################################

    # Solver options
    solver = Tsit5()
    abstol = 1e-8
    reltol = 1e-8

    # Derived quantities
    sqrtdt = sqrt(dt)
    numSteps = Int(ceil(tspan[2]/dt))
    numTvalsObs = min(numSteps,numTvalsObs)
    # Spread the saves evenly over the steps; with a single save point there is no
    # spacing to compute (the division below would be by zero).
    stepsPerSave = numTvalsObs > 1 ? max(1, floor(Int, numSteps/(numTvalsObs-1.0))) : max(1, numSteps)

    # Observables
    t_,E_ = zeros(numTvalsObs), zeros(numTvalsObs)
    x_, y_, z_, ξ_ = zeros(numTvalsObs,Ne), zeros(numTvalsObs,Ne), zeros(numTvalsObs,Ne), complex.(zeros(numTvalsObs,Ne))

    # Make the problem instance
    prob = ODEProblem(QSD_MF_deriv!,u0,tspan)

    # Solve the ODE
    println("Solving trajectory... ")
    start=time()
    integrator = init(prob,solver,save_on=false,abstol=abstol,reltol=reltol)
    saveIdx = 1
    for step=1:numSteps

        # Deterministic step of exactly dt
        step!(integrator,dt,true)

        # Stochastic step: complex Gaussian amplitude with a uniformly random phase
        ξ .= randn(Ne) .* exp.(im*2π*rand(Ne))
        StochasticStep!(integrator,ξ,sqrtdt)

        # Saving / measuring
        if mod(step-1,stepsPerSave)==0 && saveIdx <= numTvalsObs

            @inbounds for i=1:Ne
                x_[saveIdx,i] = integrator.u[tb[x,i]]
                y_[saveIdx,i] = integrator.u[tb[y,i]]
                z_[saveIdx,i] = integrator.u[tb[z,i]]
                ξ_[saveIdx,i] = ξ[i]
            end
            t_[saveIdx] = integrator.t

            # Compute energy
            computeCoefs(integrator.t)
            ωzt = ωz(integrator.t)

            E_[saveIdx] = ωzt*sum(z_[saveIdx,:]) - dot(x_[saveIdx,:],Jmod,x_[saveIdx,:]) - dot(x_[saveIdx,:],Kr,y_[saveIdx,:])
            E_[saveIdx] /= M

            saveIdx += 1
        end
    end
    print(round((time()-start)/60,digits=1));println(" (m)")

    expVals = (t_,x_,y_,z_,ξ_,E_)

    return integrator, expVals, Ne, S, g, ωz, tc, J, E_SD, E0, ωc, κ, α
end

# Run one trajectory and unpack everything main() hands back.
integrator, expVals, Ne, S, g, ωz, tc, J, E_SD, E0, ωc, κ, α = main()
t_, x_, y_, z_, ξ_, E_ = expVals

# Number of saved time points.
numTvalsObs = length(t_)

# Atoms per ensemble, recovered from the spin length S = M/2, and S re-derived from it.
M = Int(2 * S)
S = M / 2.0

#############################
# Compute interesting quantities
#############################

# Ising energy from the phase of the light, using both the continuous x components
# (normalized by S^2) and their signs; diagonal couplings J[i,i] are skipped.
E_Ising = zeros(numTvalsObs)
E_IsingBin = zeros(numTvalsObs)
for a in 1:Ne, b in 1:Ne
    a == b && continue
    E_Ising .+= (-J[a, b]) .* x_[:, a] .* x_[:, b] ./ S^2
    E_IsingBin .+= (-J[a, b]) .* sign.(x_[:, a]) .* sign.(x_[:, b])
end

# Measurement record per cavity mode m: noise of amplitude sqrt(2κ dtSave) from the
# saved ξ samples, plus a signal term built from the saved x components.
dIrec = zeros(numTvalsObs, Ne)
dQrec = zeros(numTvalsObs, Ne)
dtSave = t_[2] - t_[1]
for m in 1:Ne
    ξmode = ξ_[:, m]
    dIrec[:, m] .= sqrt(2*κ*dtSave) * real.(ξmode)
    dQrec[:, m] .= sqrt(2*κ*dtSave) * imag.(ξmode)
    for i in 1:Ne
        cavSignal = -α[i, m] * g.(t_) .* x_[:, i] / (ωc - im*κ)
        dIrec[:, m] .+= 2*κ*dtSave * (2 * real.(cavSignal))
        dQrec[:, m] .+= 2*κ*dtSave * (2 * imag.(cavSignal))
    end
end

# Map the mode-space records back onto the ensembles through α.
Ereal = zeros(numTvalsObs, Ne)
Eimag = zeros(numTvalsObs, Ne)
for i in 1:Ne, m in 1:Ne
    Ereal[:, i] .+= α[i, m] * dIrec[:, m]
    Eimag[:, i] .+= α[i, m] * dQrec[:, m]
end

# Running sums over the saved time points, scaled by dtSave (one column at a time).
ErealCum = zeros(numTvalsObs, Ne)
EimagCum = zeros(numTvalsObs, Ne)
for i in 1:Ne
    ErealCum[:, i] .= cumsum(Ereal[:, i]) * dtSave
    EimagCum[:, i] .= cumsum(Eimag[:, i]) * dtSave
end

"""
    idx2state(idx, M, Ne)

Return the `idx`-th (1-based) state of `Ne` sites as a `Vector{Int}` with entries
`1 - 2d`, where `d` runs over the base-`M + 1` digits of `idx - 1`, least significant
first and zero-padded to length `Ne`. For `M = 1` this enumerates the ±1 Ising
states, starting from all `+1`.

Throws `ArgumentError` for `idx < 1`; previously that produced negative digits and
meaningless states. An `idx` needing more than `Ne` digits still throws
`DimensionMismatch` from the broadcast.
"""
function idx2state(idx, M, Ne)
    idx >= 1 || throw(ArgumentError("idx must be a 1-based index (≥ 1), got $idx"))
    # Integer arithmetic throughout; no Float64 round-trip is needed.
    return ones(Int, Ne) .- 2 .* digits(idx - 1; base = M + 1, pad = Ne)
end

############################
# Plot results
############################

colors0 = colors15()

# X: ⟨S_i^x⟩/S for each ensemble, with the normalized pump power, the normalized
# atomic frequency and a vertical marker at tc.
fig = figure()
xlabel("Time (us)", fontsize = 12)
for k in 1:Ne
    plot(t_, x_[:, k] ./ S, "-";
         label = string(k), linewidth = 1.15, color = colors0[k],
         alpha = 1, zorder = 100 - k)
end
plot(t_, (g.(t_) / g(t_[end])).^2, "--";
     color = "black", alpha = 1, label = "Pump", zorder = 100)
plot(t_, ωz.(t_) / ωz(0), "--";
     color = "red", alpha = 1, label = L"\omega_z", zorder = 100)
axvline(tc; color = "tab:blue", linestyle = "--", alpha = 1, label = L"t_c")
ylim(-1.05, 1.05)
grid()
ylabel(L"\langle S_i^x \rangle /S", fontsize = 12)
legend(loc = (1.05, 0.0), ncol = 2)
PyPlot.display_figs()

